{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🔪 Flax - The Sharp Bits 🔪\n",
    "\n",
    "Flax exposes the full power of JAX. And just like when using JAX, there are certain _[\"sharp bits\"](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)_ you may experience when working with Flax. This evolving document is designed to assist you with them.\n",
    "\n",
    "First, install and/or update Flax:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "skip-execution"
    ]
   },
   "outputs": [],
   "source": [
    "! pip install -qq flax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔪 `flax.linen.Dropout` layer and randomness\n",
    "\n",
    "### TL;DR\n",
    "\n",
    "When working on a model with dropout (subclassed from [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)), add the `'dropout'` PRNGkey only during the forward pass.\n",
    "\n",
    "1. Start with [`jax.random.split()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax-random-split) to explicitly create PRNG keys for `'params'` and `'dropout'`.\n",
    "2. Add the [`flax.linen.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Dropout.html#flax.linen.Dropout) layer(s) to your model (subclassed from Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)).\n",
    "3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a \"simpler\" model.\n",
    "4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html), pass in `rngs={'dropout': dropout_key}`.\n",
    "\n",
    "Check out a full example below.\n",
    "\n",
    "### Why this works\n",
    "\n",
    "- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)).\n",
    "- Every time `make_rng` is called (in this case, it's done implicitly in `Dropout`), you get a new PRNG key split from the main/root PRNG key.\n",
    "- `make_rng` still _guarantees full reproducibility_.\n",
    "\n",
    "### Background \n",
    "\n",
    "The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. \n",
    "\n",
    "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n",
    "\n",
    "Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to \"pull PRNG keys\". `make_rng` guarantees to provide a unique key each time you call it.\n",
    "\n",
    "> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) is the base class for all neural network modules. All layers and models are subclassed from it.\n",
    "\n",
    "### Example\n",
    "\n",
    "Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup.\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.linen as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomness.\n",
    "seed = 0\n",
    "root_key = jax.random.key(seed=seed)\n",
    "main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)\n",
    "\n",
    "# A simple network.\n",
    "class MyModel(nn.Module):\n",
    "  num_neurons: int\n",
    "  training: bool\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    x = nn.Dense(self.num_neurons)(x)\n",
    "    # Set the dropout layer with a rate of 50% .\n",
    "    # When the `deterministic` flag is `True`, dropout is turned off.\n",
    "    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n",
    "    return x\n",
    "\n",
    "# Instantiate `MyModel` (you don't need to set `training=True` to\n",
    "# avoid performing the forward pass computation).\n",
    "my_model = MyModel(num_neurons=3, training=False)\n",
    "\n",
    "x = jax.random.uniform(key=main_key, shape=(3, 4, 4))\n",
    "\n",
    "# Initialize with `flax.linen.init()`.\n",
    "# The `params_key` is equivalent to a dictionary of PRNGs.\n",
    "# (Here, you are providing only one PRNG key.) \n",
    "variables = my_model.init(params_key, x)\n",
    "\n",
    "# Perform the forward pass with `flax.linen.apply()`.\n",
    "y = my_model.apply(variables, x, rngs={'dropout': dropout_key})"
   ]
  },
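  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reproducibility point made earlier can be checked with a minimal, self-contained sketch. `DropoutDemo` and the `demo_*` names are hypothetical and exist only for this illustration. The cell assumes a JAX version that provides `jax.random.key`, as used above. It applies the same variables in training mode twice with one `'dropout'` key and once with a different key, then runs in evaluation mode without any `'dropout'` key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.linen as nn\n",
    "\n",
    "# Hypothetical demo module (not part of Flax): Dense followed by dropout.\n",
    "class DropoutDemo(nn.Module):\n",
    "  training: bool\n",
    "\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    x = nn.Dense(8)(x)\n",
    "    return nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n",
    "\n",
    "demo_root = jax.random.key(0)\n",
    "demo_params_key, key_a, key_b = jax.random.split(demo_root, 3)\n",
    "demo_x = jnp.ones((4, 16))\n",
    "\n",
    "# Initialize in evaluation mode: only the 'params' stream is needed.\n",
    "demo_vars = DropoutDemo(training=False).init(demo_params_key, demo_x)\n",
    "\n",
    "# Training mode: dropout is active and draws keys from the 'dropout' stream.\n",
    "train_demo = DropoutDemo(training=True)\n",
    "y_a1 = train_demo.apply(demo_vars, demo_x, rngs={'dropout': key_a})\n",
    "y_a2 = train_demo.apply(demo_vars, demo_x, rngs={'dropout': key_a})\n",
    "y_b = train_demo.apply(demo_vars, demo_x, rngs={'dropout': key_b})\n",
    "\n",
    "# Evaluation mode: dropout is off, so no 'dropout' key is required.\n",
    "y_eval = DropoutDemo(training=False).apply(demo_vars, demo_x)\n",
    "\n",
    "# Compare the outputs: same key twice, then two different keys.\n",
    "print('same key, equal outputs:      ', bool(jnp.array_equal(y_a1, y_a2)))\n",
    "print('different keys, equal outputs:', bool(jnp.array_equal(y_a1, y_b)))"
   ]
  },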
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Real-life examples:\n",
    "\n",
    "* Applying word dropout to a batch of input IDs (in a [text classification](https://github.com/google/flax/blob/main/examples/sst2/models.py) context).\n",
    "* Defining a prediction token in a decoder of a [sequence-to-sequence model](https://github.com/google/flax/blob/main/examples/seq2seq/models.py)."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
